{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "60KmTK7o6ppd"
},
"source": [
"##### Copyright 2024 Google LLC."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "GoAJlrr6LUyj"
},
"outputs": [],
"source": [
"# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a-xjCcnrf10b"
},
"source": [
"# Gemma - self extend Gemma\n",
"\n",
"This is one of the accompanying notebooks for the [Large Language Models with Keras](https://www.youtube.com/watch?v=TV7qCk1dBWA) technical session at Google I/O 2024 and it demonstrates how extend the context window of Gemma.\n",
"\n",
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Workshops/[Gemma_1]Self_extend.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2hMRutWLgMe7"
},
"source": [
"# Setup\n",
"\n",
"### Select the Colab runtime\n",
"To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
"\n",
"1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
"2. Select **Change runtime type**.\n",
"3. Under **Hardware accelerator**, select **L4 GPU** or **A100 GPU**."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AO7a1Q4Yyc9Z"
},
"source": [
"# Installation\n",
"\n",
"Install Keras and KerasNLP with the Gemma model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GS7Kjiw58X9T"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m19.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m590.6/590.6 MB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m58.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m66.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m57.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m60.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m311.2/311.2 kB\u001b[0m \u001b[31m32.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Building wheel for keras-nlp (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.16.2 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
}
],
"source": [
"!pip install -q -U git+https://github.com/keras-team/keras-nlp.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WEgg_OVIL2HY"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
"\n",
"import keras\n",
"import keras_nlp\n",
"\n",
"# for reproducibility\n",
"keras.utils.set_random_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "h5H7oE9MtnBA"
},
"outputs": [],
"source": [
"# Access to Gemma checkpoints on Kaggle\n",
"\n",
"# On Colab, set up you Kaggle credentials as Colab secrets and use this:\n",
"\n",
"from google.colab import userdata\n",
"\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wo1xkzr62hXN"
},
"source": [
"# Load the 2B model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dTULEQ29qidE"
},
"source": [
" Load the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bu48vUnbQj0p"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Attaching 'model.safetensors' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'model.safetensors.index.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'model.safetensors' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'model.safetensors.index.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'model.safetensors' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'model.safetensors.index.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
"Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n"
]
}
],
"source": [
"gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma_1.1_instruct_2b_en\")\n",
"gemma_lm.compile(sampler=keras_nlp.samplers.GreedySampler()) # this is the default"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mS7jstFSlGIO"
},
"source": [
"# Add car dealership data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c3uuWiQ6HmH8"
},
"outputs": [],
"source": [
"retrieval_car = \"0. 2024 Google IO Car: Full of interesting demos and talks. Runs on AI, pretty good mileage. Ask for a discount at the dealership with this number 1234123.\"\n",
"car_dealership_data = \"\"\"\n",
"1. 2019 Chevrolet Silverado 1500 LT, Black: This workhorse is equipped with a powerful V8 engine and a spacious crew cab,offering a blend of power and comfort. Whether you're hauling heavy loads or cruising down the highway, this truck delivers a smooth and confident ride.\n",
"2. 2021 Toyota Camry SE, Silver: The Toyota Camry is known for its reliability and fuel efficiency, making it a practical choice for daily commuting. This particular SE trim adds a sportier flair with its sleek exterior design and responsive handling.\n",
"3. 2020 Ford F-150 XLT, Blue: The Ford F-150 is a perennial favorite in the truck market, and this XLT model features a comfortable interior and a variety of useful tech features. With its powerful engine and capable towing capacity, it's ready for work or play.\n",
"4. 2022 Honda Civic LX, White: The Honda Civic is a popular choice for those seeking a reliable and efficient sedan. This LX model is equipped with essential features and offers a comfortable and practical driving experience.\n",
"5. 2023 Nissan Altima SR, Red: The Nissan Altima SR boasts a sporty appearance and a peppy engine, making it a fun-to-drive sedan. With its modern interior and available technology features, it offers both style and substance.\n",
"6. 2018 Jeep Grand Cherokee Laredo, Green: The Jeep Grand Cherokee is known for its ruggedness and off-road capability, making it a popular choice for adventure seekers. This Laredo model offers a comfortable ride and plenty of cargo space for all your gear.\n",
"7. 2020 Ram 1500 Big Horn, Gray: The Ram 1500 Big Horn is a well-equipped truck with a powerful engine and a spacious cabin. It's a versatile workhorse that's also comfortable for everyday driving.\n",
"8. 2019 Subaru Outback Limited, Brown: The Subaru Outback is a popular choice for those seeking a vehicle with all-wheel drive and plenty of cargo space. This Limited model features a luxurious interior and a variety of advanced safety features.\n",
"9. 2021 Hyundai Tucson SEL, Silver: The Hyundai Tucson is a stylish and practical compact SUV with a comfortable interior and a smooth ride. This SEL model is equipped with a variety of modern features to enhance your driving experience.\n",
"10. 2023 Kia Sportage EX, Blue: The Kia Sportage is a popular SUV with a stylish exterior and a comfortable interior. This EX model is equipped with a range of features that make it a great choice for families or adventurers.\n",
"11. 2017 Ford Escape SE, White: The Ford Escape is a versatile compact SUV with a comfortable ride and a spacious interior. This SE model is equipped with a range of features that make it a great choice for families or adventurers.\n",
"12. 2022 Chevrolet Equinox LT, Red: The Chevrolet Equinox is a popular SUV with a stylish exterior and a comfortable interior. This LT model is equipped with a range of features that make it a great choice for families or adventurers.\n",
"13. 2020 Toyota RAV4 XLE, Black: The Toyota RAV4 is a reliable and fuel-efficient SUV that's popular for its practicality and versatility. This XLE model is equipped with a variety of features to enhance your driving experience.\n",
"14. 2018 Honda CR-V EX-L, Gray: The Honda CR-V is a perennial favorite in the SUV market, known for its reliability and comfortable ride. This EX-L model offers a luxurious interior and a range of advanced features.\n",
"15. 2021 Mazda CX-5 Touring, Brown: The Mazda CX-5 is a stylish and sporty SUV with a responsive handling and a comfortable ride. This Touring model is equipped with a variety of modern features to enhance your driving experience.\n",
"16. 2023 Nissan Rogue SL, Silver: The Nissan Rogue SL is a top-of-the-line SUV with a luxurious interior and a range of advanced features. It's a great choice for those seeking a spacious and comfortable vehicle with all the bells and whistles.\n",
"17. 2019 Dodge Durango RTT, Blue: The Dodge Durango RTT is a powerful and spacious SUV with a range of performance-oriented features. It's a great choice for those seeking a vehicle that can handle both everyday driving and adventurous off-road excursions.\n",
"18. 2020 GMC Acadia SLE, White: The GMC Acadia is a spacious and versatile SUV. This SLE model is equipped with a variety of modern features to enhance your driving experience.\n",
"19. 2018 Jeep Wrangler Sahara, Green: The Jeep Wrangler is an iconic off-road vehicle with a removable top and a rugged design. This Sahara model is equipped with a range of features that make it comfortable for everyday driving as well as off-road adventures.\n",
"20. 2021 Ford Bronco Sport Badlands, Orange: The Ford Bronco Sport Badlands is a rugged and capable SUV designed for off-road adventures. It's equipped with a range of features that make it ready to tackle any terrain.\n",
"21. 2022 Toyota 4Runner TRD Off-Road, Gray: The Toyota 4Runner is a legendary off-road vehicle with a rugged design and a powerful engine. This TRD Off-Road model is equipped with a range of features that make it ready for any adventure.\n",
"22. 2020 Chevrolet Tahoe LT, Black: The Chevrolet Tahoe is a spacious and versatile SUV with a comfortable interior and a powerful engine. This LT model is equipped with a range of features that make it a great choice for families or adventurers.\n",
"23. 2019 GMC Yukon Denali, White: The GMC Yukon Denali is a luxurious and powerful SUV with a spacious interior and a range of premium features. It's a great choice for those seeking a comfortable and stylish vehicle with plenty of power.\n",
"24. 2021 Cadillac Escalade Premium Luxury, Blue: The Cadillac Escalade is a flagship SUV with a luxurious interior and a powerful engine. This Premium Luxury model is equipped with a range of advanced features that make it a true statement of luxury and style.\n",
"25. 2022 Lincoln Navigator Reserve, Black: The Lincoln Navigator is a luxurious and spacious SUV. This Reserve model is equipped with a range of premium features that make it a great choice for those seeking a refined and comfortable driving experience.\n",
"26. 2023 Lexus LX 600 Ultra Luxury, Brown: The Lexus LX 600 Ultra Luxury is a top-of-the-line SUV with a luxurious interior and a range of advanced features. It's a great choice for those seeking a vehicle that combines opulence, technology, and off-road capability.\n",
"27. 2020 BMW X5 xDrive40i, Silver: The BMW X5 is a performance-oriented SUV with a powerful engine and agile handling. This xDrive40i model is equipped with a range of features that make it a great choice for those seeking a luxurious and sporty driving experience.\n",
"28. 2019 Mercedes-Benz GLE 450 4MATIC, White: The Mercedes-Benz GLE is a stylish and luxurious SUV with a comfortable interior and a powerful engine. This 450 4MATIC model is equipped with a range of advanced features that make it a great choice for those seeking a refined and comfortable driving experience.\n",
"29. 2021 Audi Q7 Premium Plus, Blue: The Audi Q7 is a spacious and luxurious SUV with a comfortable interior and a powerful engine. This Premium Plus model is equipped with a range of advanced features that make it a great choice for those seeking a refined and comfortable driving experience.\n",
"30. 2023 Porsche Cayenne E-Hybrid, Red: The Porsche Cayenne is a performance-oriented SUV with a powerful engine and agile handling. This E-Hybrid model is equipped with a range of features that make it a great choice for those seeking a luxurious and sporty driving experience.\n",
"31. 2022 Land Rover Range Rover Velar R-Dynamic SE, Black: The Land Rover Range Rover Velar is a stylish and luxurious SUV with a comfortable interior and a powerful engine. This R-Dynamic SE model is equipped with a range of advanced features that make it a great choice for those seeking a refined and comfortable driving experience.\n",
"32. 2020 Tesla Model Y Long Range, White: The Tesla Model Y is a popular electric SUV with a spacious interior, impressive range, and advanced technology features. It's a great choice for those seeking a sustainable and innovative vehicle with plenty of space and comfort.\n",
"33. 2018 Nissan Leaf SL Plus, Gray: The Nissan Leaf is a popular electric car with a comfortable ride, spacious interior, and impressive fuel efficiency. This SL Plus model is equipped with a range of features that make it a great choice for those seeking a practical and affordable electric vehicle.\n",
"34. 2021 Chevrolet Bolt EV Premier, Blue: The Chevrolet Bolt EV is a popular electric car with a fun driving experience, spacious interior, and impressive range. This Premier model is equipped with a range of features that make it a great choice for those seeking a practical and affordable electric vehicle.\n",
"35. 2023 Hyundai Kona Electric Limited, Red: The Hyundai Kona Electric is a stylish and practical electric SUV. This Limited model is equipped with a range of features that make it a great choice for those seeking a sustainable and stylish vehicle.\n",
"36. 2022 Kia Niro EV EX Premium, Black: The Kia Niro EV is a popular electric car with a comfortable ride, spacious interior, and impressive fuel efficiency. This EX Premium model is equipped with a range of features that make it a great choice for those seeking a practical and affordable electric vehicle.\n",
"37. 2020 Volkswagen ID.4 Pro, Silver: The Volkswagen ID.4 is a popular electric SUV with a spacious interior, modern features, and impressive range. It's a great choice for those seeking a sustainable and innovative vehicle with plenty of space and comfort.\n",
"38. 2019 Tesla Model 3 Long Range, White: The Tesla Model 3 is a popular electric sedan with a sleek design, impressive range, and advanced technology features. It's a great choice for those seeking a sustainable and innovative vehicle with a sporty driving experience.\n",
"39. 2021 Ford Mustang Mach-E Premium, Red: The Ford Mustang Mach-E is a popular electric SUV with a powerful performance, stylish design, and modern features. It's a great choice for those seeking a sustainable and sporty vehicle with plenty of space and comfort.\n",
"40. 2023 Hyundai IONIQ 5 SE, Blue: The Hyundai IONIQ 5 is a popular electric car with a spacious interior, modern features, and impressive range. It's a great choice for those seeking a sustainable and innovative vehicle with plenty of space and comfort.\n",
"41. 2022 Kia EV6 GT-Line, Black: The Kia EV6 is a popular electric SUV with a stylish exterior, modern interior, and impressive range. It's a great choice for those seeking a sustainable and innovative vehicle with plenty of space and comfort.\n",
"42. 2020 Nissan Versa S, White: The Nissan Versa is a budget-friendly sedan that offers a surprising amount of interior space and fuel efficiency. This S model is a great choice for those seeking a practical and affordable vehicle for everyday commuting.\n",
"43. 2019 Mitsubishi Mirage ES, Red: The Mitsubishi Mirage is a compact car known for its fuel efficiency and affordability. This ES model offers a few extra features and a bit more style to make your daily commute more enjoyable.\n",
"44. 2021 Chevrolet Spark LS, Black: The Chevrolet Spark is a small and nimble city car that's easy to park and maneuver in tight spaces. This LS model is a great choice for those seeking an affordable and practical vehicle for city driving.\n",
"45. 2022 Kia Rio LX, Gray: The Kia Rio is a stylish and affordable subcompact car with a surprisingly spacious interior. This LX model is a great choice for those seeking a practical and stylish vehicle for everyday commuting.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a5eYP2UBHsH9"
},
"outputs": [],
"source": [
"__START_TURN_USER__ = \"<start_of_turn>user\\n\"\n",
"__START_TURN_MODEL__ = \"<start_of_turn>model\\n\"\n",
"__END_TURN__ = \"<end_of_turn>\\n\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yOTBqH3nA0xX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prompt has 2716 tokens\n",
"====================\n",
"Gemma output: The discount code is 1234123.\n",
"====================\n",
"Actual discount code is 1234123\n",
"\n",
"\n",
"\n"
]
}
],
"source": [
"prompt_postfix = (\n",
" \"\\nWhat is the discount code for the 2024 Google IO Car?\"\n",
" + __END_TURN__\n",
" + __START_TURN_MODEL__\n",
" + \" The discount code is \"\n",
")\n",
"prompt = __START_TURN_USER__ + retrieval_car + car_dealership_data * 1 + prompt_postfix\n",
"tokenized = gemma_lm.preprocessor.tokenizer.tokenize(prompt)\n",
"print(f\"Prompt has {len(tokenized)} tokens\")\n",
"gemma_page = gemma_lm.generate(prompt, max_length=len(tokenized) + 20)\n",
"gemma_page = gemma_page.split(__START_TURN_MODEL__)[1]\n",
"print(\"=\" * 20)\n",
"print(\"Gemma output:\", gemma_page)\n",
"print(\"=\" * 20)\n",
"print(f\"Actual discount code is 1234123\")\n",
"print(\"\\n\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EQXvY2UDLsyj"
},
"source": [
"# Create self-extend attention call"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "znfYstRD9m6U"
},
"outputs": [],
"source": [
"from keras import ops\n",
"\n",
"GROUP_SIZE = 8\n",
"WINDOW_SIZE = 1024\n",
"\n",
"\n",
"def extend_compute_attention(\n",
" self,\n",
" query,\n",
" key,\n",
" value,\n",
" attention_mask,\n",
" training=False,\n",
"):\n",
" # Shapes for attention.\n",
" head_dim, kv_heads = self.head_dim, self.num_key_value_heads\n",
" batch_size, query_len = ops.shape(query)[:2]\n",
" key_len = ops.shape(key)[1]\n",
"\n",
" def apply_rope(inputs, positions):\n",
" outputs = self.rope_layer(inputs, positions=positions)\n",
" # Gemma has a different shape for Rope outputs\n",
" outputs = ops.stack(ops.split(outputs, 2, axis=-1), axis=-1)\n",
" return ops.reshape(outputs, ops.shape(inputs))\n",
"\n",
" def attn(query, key):\n",
" query *= ops.cast(1 / ops.sqrt(head_dim), dtype=query.dtype)\n",
" new_shape = (batch_size, query_len, kv_heads, -1, head_dim)\n",
" query = ops.reshape(query, new_shape)\n",
" return ops.einsum(\"btkgh,bskh->bkgts\", query, key)\n",
"\n",
" # Calculate normal positions.\n",
" cache_index = ops.sum(attention_mask[0, 0, :]) - 1\n",
" query_positions = ops.arange(query_len) + cache_index\n",
" key_positions = ops.arange(key_len)\n",
"\n",
" # Normal RoPE and compute logits.\n",
" local_query = apply_rope(query, query_positions)\n",
" local_key = apply_rope(key, key_positions)\n",
" local_logits = attn(local_query, local_key)\n",
"\n",
" # Calculate grouped positions.\n",
" query_positions = query_positions // GROUP_SIZE\n",
" query_positions += WINDOW_SIZE - GROUP_SIZE // WINDOW_SIZE\n",
" key_positions = key_positions // GROUP_SIZE\n",
"\n",
" # Grouped RoPE and compute logits.\n",
" grouped_query = apply_rope(query, query_positions)\n",
" grouped_key = apply_rope(key, key_positions)\n",
" grouped_logits = attn(grouped_query, grouped_key)\n",
"\n",
" # Combine attention logits.\n",
" attn_mask = attention_mask[:, None, None, :, :]\n",
" # Flip the attn mask, cumsum, and flip back to get relative positions.\n",
" group_mask = ops.flip(ops.cumsum(ops.flip(attn_mask, -1), -1), -1)\n",
" # Take local attention where relative postions are less than window size.\n",
" local_mask = group_mask < WINDOW_SIZE\n",
" logits = ops.where(local_mask, local_logits, grouped_logits)\n",
"\n",
" # Softmax and weighted value sum.\n",
" scores = ops.cast(self.softmax(logits, attn_mask), self.compute_dtype)\n",
" results = ops.einsum(\"bkgts,bskh->btkgh\", scores, value)\n",
" return ops.reshape(results, (batch_size, query_len, -1, head_dim))"
]
},
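{
"cell_type": "markdown",
"metadata": {},
"source": [
"The position remapping can be illustrated outside the model. The following is a minimal NumPy sketch, not the KerasNLP implementation: the helper name `self_extend_relative_positions` and the toy sizes are made up for illustration, and it assumes the query shift `window_size - window_size // group_size` described in the SelfExtend paper. It suggests that nearby tokens keep their exact relative positions, while distant tokens are compressed into a much smaller range that the model has already seen during training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def self_extend_relative_positions(seq_len, group_size, window_size):\n",
"    # Hypothetical helper: the relative position RoPE would see for each\n",
"    # (query, key) pair under the assumptions stated above.\n",
"    positions = np.arange(seq_len)\n",
"    causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=int))\n",
"\n",
"    # Ordinary relative distance: query position minus key position.\n",
"    local_rel = positions[:, None] - positions[None, :]\n",
"\n",
"    # Grouped distance: floor-divide positions, then shift the queries so\n",
"    # the grouped range continues roughly where the local window ends.\n",
"    shift = window_size - window_size // group_size\n",
"    grouped_query = positions[:, None] // group_size + shift\n",
"    grouped_rel = grouped_query - positions[None, :] // group_size\n",
"\n",
"    # Flip, cumsum, flip: inside the causal region this equals\n",
"    # query - key + 1, the same trick used for `group_mask` above.\n",
"    distance = np.flip(np.cumsum(np.flip(causal_mask, -1), -1), -1)\n",
"    rel = np.where(distance < window_size, local_rel, grouped_rel)\n",
"    return rel * causal_mask  # zero out future (masked) positions\n",
"\n",
"\n",
"rel = self_extend_relative_positions(seq_len=32, group_size=4, window_size=8)\n",
"print(\"Largest plain relative position:\", 32 - 1)\n",
"print(\"Largest self-extend relative position:\", rel.max())\n",
"print(\"Last query row:\", rel[-1])"
]
},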
{
"cell_type": "markdown",
"metadata": {
"id": "2jnQoTupaFoF"
},
"source": [
"# Patching code"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1RtCV0dhaIKZ"
},
"outputs": [],
"source": [
"from types import MethodType\n",
"\n",
"for layer in gemma_lm.backbone.transformer_layers:\n",
" # Patch model rope call's with identity\n",
" # Our rope call's will happen inside of our attention computation\n",
" layer.attention._apply_rope = MethodType(lambda s, x, i: x, layer.attention)\n",
" # Patching model attention with self_extend version\n",
" layer.attention._compute_attention = MethodType(\n",
" extend_compute_attention, layer.attention\n",
" )"
]
},
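{
"cell_type": "markdown",
"metadata": {},
"source": [
"The patching above relies on `types.MethodType`, which binds a plain function to one specific instance so that the function receives that instance as `self`. Here is a minimal sketch of the mechanism, using a hypothetical `Layer` class as a stand-in rather than a real KerasNLP layer. Only the patched instance changes behavior; other instances of the class are left untouched."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from types import MethodType\n",
"\n",
"\n",
"class Layer:  # hypothetical stand-in, not a KerasNLP class\n",
"    def __init__(self, scale):\n",
"        self.scale = scale\n",
"\n",
"    def compute(self, x):\n",
"        return x * self.scale\n",
"\n",
"\n",
"def patched_compute(self, x):\n",
"    # Receives the instance as `self`, so it can still read its attributes.\n",
"    return x * self.scale + 1\n",
"\n",
"\n",
"first, second = Layer(2), Layer(3)\n",
"# Bind the replacement to `first` only; the class itself is unchanged.\n",
"first.compute = MethodType(patched_compute, first)\n",
"\n",
"print(\"Patched instance:\", first.compute(10))\n",
"print(\"Untouched instance:\", second.compute(10))"
]
},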
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A9mh1itbq3uF"
},
"outputs": [],
"source": [
"gemma_lm.built = False\n",
"gemma_lm.generate_function = None\n",
"keras.config.disable_traceback_filtering()\n",
"gemma_lm.compile(sampler=keras_nlp.samplers.GreedySampler())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KlIsOfcykYwN"
},
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"\"<start_of_turn>user\\nAnswer me: Hi there<end_of_turn>\\n<start_of_turn>model\\nHi there! 👋 It's great to hear from you. What would you like to talk about today?\""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gemma_lm.generate(\n",
" __START_TURN_USER__ + \"Answer me: Hi there\" + __END_TURN__ + __START_TURN_MODEL__,\n",
" max_length=80,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xaRDrieuUA6a"
},
"source": [
"# Test with SelfExtend Algorithm - 10k context window"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9O15mj27Soaq"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prompt has 10639 tokens\n",
"====================\n",
"Gemma output: The discount code is 1234123\n",
"====================\n",
"Actual discount code is 1234123\n",
"\n",
"\n",
"\n"
]
}
],
"source": [
"prompt_postfix = (\n",
" \"\\nWhat is the discount code for the 2024 Google IO Car?\"\n",
" + __END_TURN__\n",
" + __START_TURN_MODEL__\n",
" + \" The discount code is \"\n",
")\n",
"prompt = __START_TURN_USER__ + retrieval_car + car_dealership_data * 4 + prompt_postfix\n",
"tokenized = gemma_lm.preprocessor.tokenizer.tokenize(prompt)\n",
"print(f\"Prompt has {len(tokenized)} tokens\")\n",
"gemma_page = gemma_lm.generate(prompt, max_length=len(tokenized) + 20)\n",
"gemma_page = gemma_page.split(__START_TURN_MODEL__)[1].split(\".\")[0]\n",
"print(\"=\" * 20)\n",
"print(\"Gemma output:\", gemma_page)\n",
"print(\"=\" * 20)\n",
"print(f\"Actual discount code is 1234123\")\n",
"print(\"\\n\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bzKsCGIN0yX5"
},
"source": [
"# What's next\n",
"\n",
"In this tutorial, you learned how to extend the context window of Gemma, using Keras on JAX.\n",
"\n",
"Here are a few suggestions for what else to learn, about Keras and JAX:\n",
"* [Distributed training with Keras 3](https://keras.io/guides/distribution/).\n",
"* [Writing a custom training loop for a Keras model in JAX](https://keras.io/guides/writing_a_custom_training_loop_in_jax/).\n",
"\n",
"And a couple of more basic Gemma tutorials:\n",
"\n",
"* [Get started with Keras Gemma](https://ai.google.dev/gemma/docs/get_started).\n",
"* [Finetune the Gemma model on GPU](https://ai.google.dev/gemma/docs/lora_tuning)."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "[Gemma_1]Self_extend.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}